#!/usr/bin/env python
# -*- coding:utf-8 -*-
import sys

sys.path.append('..')
from common_utils import CBBCommonUtils


def get_plugin_info():
    """Return the static metadata that describes this hardening plugin.

    A new dict is built on every call, so callers may mutate the result freely.
    """
    return dict(
        name="centos_authentication_006 Set_continuous_authentication_failure_lock_account",
        plugin_id="centos_authentication_006",
        plugin_type="Authentication",
        info="When consecutive authentication failures of a user exceeds 6 times, lock the account.",
        level="A",
        module="Safety reinforcement",
        author="congshz",
        keyword="Safety reinforcement",
        configable="false",
    )


# Logging context shared by every function in this module; all four start as
# None and are filled in by set_plugin_logger().
logger = cur_user = cur_module = cur_task = None


def set_plugin_logger(setter, user, module, task, *args, **kwargs):
    """Install the logger and the user/module/task context used by log calls.

    Any extra positional or keyword arguments are accepted and ignored.
    """
    global logger, cur_user, cur_module, cur_task
    logger, cur_user, cur_module, cur_task = setter, user, module, task


def scan_file(ip, cmd, filename, sys_user, sys_pwd):
    """Run one check command on the target host and classify its outcome.

    ``cmd`` is a grep-style pipeline that prints the matching configuration
    lines. The outcome is classified as follows:

    - A falsy ``result`` with some output means the setting is present (SAFE).
    - A falsy ``result`` with no output means it is missing (DANGEROUS).
    - A truthy ``result`` is reported as a command error.

    Returns:
        ([description], error_count), where error_count is 0 for SAFE and 1
        otherwise.
    """
    result, output = CBBCommonUtils.cbb_run_cmd(ip, cmd, username=sys_user, passwd=sys_pwd)
    tag = cur_task + '_scan_file'
    if result:
        message = "CMD process ERROR. Error info: {}".format(result[0].replace("\n", ""))
        logger.debug_error(cur_user, cur_module, tag, '', message)
        return [message], 1
    if len(output) != 0:
        message = "{} is SAFE.".format(filename)
        logger.debug_info(cur_user, cur_module, tag, '', message)
        return [message], 0
    message = "{} is DANGEROUS.".format(filename)
    logger.debug_warning(cur_user, cur_module, tag, '', message)
    return [message], 1


# Scan: check whether the account-lockout settings are present on the target host.
def scan(ip, sys_user, sys_pwd, flag=0, reinforce_flag=False):
    """Check /etc/pam.d/system-auth and /etc/pam.d/sshd for the lockout configuration.

    Args:
        ip, sys_user, sys_pwd: Target host and credentials for the remote commands.
        flag: When 1 and any check fails, call reinforce() to apply the fix.
        reinforce_flag: True when reinforce() calls this to verify its own changes.
            This suppresses the "Scan Complete." entry and never triggers
            reinforce() again.

    Returns:
        (des_list, error_count). error_count is the number of failed checks. It is
        reset to 0 when reinforce() ran and reported no errors.
    """
    # (check command, label used in the report), run in this order.
    checks = (
        ("cat /etc/pam.d/system-auth | grep deny=", "auth_fail_locking"),
        ("cat /etc/pam.d/system-auth | grep account | grep required | grep pam_tally.so", "account_auth"),
        ("cat /etc/pam.d/sshd | grep deny | grep unlock_time", "ssh_auth"),
        # Distinct label from the sshd deny check above so the report can tell the
        # two sshd results apart.
        ("cat /etc/pam.d/sshd | grep account | grep required | grep pam_tally2.so", "ssh_account_auth"),
    )
    des_list = []
    error_count = 0
    for cmd, label in checks:
        des, err = scan_file(ip, cmd, label, sys_user, sys_pwd)
        des_list.extend(des)
        error_count += err

    if reinforce_flag:
        return des_list, error_count

    des_list.append("Scan Complete.")
    logger.debug_info(cur_user, cur_module, cur_task + '_Scan', '', 'Scan Complete.')
    # Reinforce only when asked to (flag == 1) and something is missing.
    if error_count != 0 and flag == 1:
        reinforce_des, reinforce_err = reinforce(ip, sys_user, sys_pwd)
        des_list.extend(reinforce_des)
        if reinforce_err == 0:
            error_count = 0
    return des_list, error_count


def _ensure_pam_line(ip, sys_user, sys_pwd, grep_cmd, missing_cmd, present_cmd=None):
    """Make sure one PAM configuration line is in place on the target host.

    Runs ``grep_cmd`` and then acts on what it finds:

    - No match: run ``missing_cmd`` (an ``echo ... >> file`` append).
    - A match and ``present_cmd`` is given: run ``present_cmd`` (an in-place
      ``sed`` rewrite).
    - A match and ``present_cmd`` is None: leave the file untouched.

    The exit status of the modifying command is not checked here; reinforce()
    re-scans afterwards to verify the outcome.

    Returns:
        None on success, or an error description when ``grep_cmd`` failed.
    """
    result, output = CBBCommonUtils.cbb_run_cmd(ip, grep_cmd, username=sys_user, passwd=sys_pwd)
    if result:
        return "Get target ERROR.Error info: {}.".format(result[0].replace("\n", ""))
    fix_cmd = missing_cmd if len(output) == 0 else present_cmd
    if fix_cmd is not None:
        CBBCommonUtils.cbb_run_cmd(ip, fix_cmd, username=sys_user, passwd=sys_pwd)
    return None


# Reinforcement: back up, apply the lockout settings, verify, and roll back on failure.
def reinforce(ip, sys_user, sys_pwd):
    """Apply the account-lockout settings to system-auth and sshd on the target.

    Steps:
        1. Back both files up as ``<file>.D.<YYYY-MM-DD>``.
        2. Make sure each lockout line is present.
        3. Re-run scan() to verify.
        4. Call rollback() if anything failed.

    Returns:
        (des_list, error_count). error_count is 0 when reinforcement succeeded.
    """
    des_list = []
    error_count = 0
    # Back up both files. The "D" in the suffix stands for "DENY"; it ties the
    # backups to this feature so they do not collide with other plugins' backups.
    cmd = "cp /etc/pam.d/system-auth /etc/pam.d/system-auth.D.`date +%Y-%m-%d`"
    cmd_ssh = "cp /etc/pam.d/sshd /etc/pam.d/sshd.D.`date +%Y-%m-%d`"
    result, output = CBBCommonUtils.cbb_run_cmd(ip, cmd, username=sys_user, passwd=sys_pwd)
    result1, output1 = CBBCommonUtils.cbb_run_cmd(ip, cmd_ssh, username=sys_user, passwd=sys_pwd)
    if result or result1:
        error_count += 1
        backup_des = "Backup FAILED."
        logger.debug_error(cur_user, cur_module, cur_task + '_Reinforce', '', backup_des)
        des_list.append(backup_des)
        return des_list, error_count

    # Confirm that both backup files now exist.
    check_cmd = "ls /etc/pam.d/ | grep system-auth.D.`date +%Y-%m-%d`"
    check_cmd1 = "ls /etc/pam.d/ | grep sshd.D.`date +%Y-%m-%d`"
    check_result, check_output = CBBCommonUtils.cbb_run_cmd(ip, check_cmd, username=sys_user, passwd=sys_pwd)
    check_result1, check_output1 = CBBCommonUtils.cbb_run_cmd(ip, check_cmd1, username=sys_user, passwd=sys_pwd)
    if check_result or check_result1 or len(check_output) == 0 or len(check_output1) == 0:
        error_count += 1
        backup_check_des = "Backup check FAILED."
        logger.debug_error(cur_user, cur_module, cur_task + '_Reinforce', '', backup_check_des)
        des_list.append(backup_check_des)
        return des_list, error_count
    logger.debug_info(cur_user, cur_module, cur_task + '_Reinforce', '', "Backup SUCCEED.")

    # Each entry is (grep for an existing line, command if missing, command if present).
    # Lines that are already present are left alone. The exception is the
    # system-auth auth line, which is rewritten so it carries the required options.
    # NOTE(review): appended lines land at the end of the file. PAM evaluates
    # modules in order, so confirm an appended pam_tally line is reached before
    # any earlier "sufficient" auth module.
    steps = (
        ("cat /etc/pam.d/system-auth | grep auth | grep required | grep pam_tally.so",
         "echo 'auth        required      pam_tally.so deny=5 unlock_time=300'>>/etc/pam.d/system-auth",
         # sed "c" replaces each matching line with the text that follows it.
         # The address is limited to auth lines so an "account ... pam_tally.so"
         # line is not rewritten.
         "sed -i '/auth.*pam_tally\\.so/cauth        required      pam_tally.so deny=5 unlock_time=300'"
         " \"/etc/pam.d/system-auth\""),
        ("cat /etc/pam.d/system-auth | grep account | grep required | grep pam_tally.so",
         "echo 'account     required      pam_tally.so'>>/etc/pam.d/system-auth",
         None),
        ("cat /etc/pam.d/sshd | grep deny | grep unlock_time",
         "echo 'auth required pam_tally2.so deny=5 unlock_time=600' >> /etc/pam.d/sshd",
         None),
        ("cat /etc/pam.d/sshd | grep account | grep required | grep pam_tally2.so",
         "echo 'account required pam_tally2.so' >> /etc/pam.d/sshd",
         None),
    )
    for grep_cmd, missing_cmd, present_cmd in steps:
        cat_error = _ensure_pam_line(ip, sys_user, sys_pwd, grep_cmd, missing_cmd, present_cmd)
        if cat_error is not None:
            error_count += 1
            logger.debug_error(cur_user, cur_module, cur_task + '_Reinforce', '', cat_error)
            des_list.append(cat_error)

    # Verify the changes with the same checks the scan uses.
    des_reinforce_scan, error_reinforce_scan = scan(ip, sys_user, sys_pwd, flag=0, reinforce_flag=True)
    des_list.extend(des_reinforce_scan)
    error_count = error_count + error_reinforce_scan
    if error_count != 0:
        # Reinforcement failed: restore the backups made above. rollback()
        # reports its own problems through the returned descriptions.
        des = "Reinforce FAILED."
        des_list.append(des)
        logger.debug_error(cur_user, cur_module, cur_task + '_Reinforce', '', des)
        rollback_des, rollback_err = rollback(ip, sys_user, sys_pwd)
        des_list.extend(rollback_des)
        return des_list, error_count

    des_list.append("Reinforce Complete.")
    logger.debug_info(cur_user, cur_module, cur_task + '_Reinforce', '', "Reinforce Complete.")
    return des_list, error_count


# Rollback: restore both PAM files from their newest backups.
def rollback(ip, sys_user, sys_pwd):
    """Restore system-auth and sshd from their newest ``.D.`` backups, then delete those backups.

    Backups are the ``<file>.D.<YYYY-MM-DD>`` copies written by reinforce(). The
    last name printed by ``ls`` is used as the newest. Because each successful
    call deletes the backup it restored, repeated calls step back one backup at
    a time.

    Returns:
        (des_list, error_count). error_count is:
        - -1 when no backup pair was found,
        - 1 when restoring or deleting the backups failed,
        - 0 on success.
    """
    des_list = []
    error_count = 0
    # The grep patterns are quoted so the remote shell cannot glob-expand them
    # against files in its working directory. The dots are escaped so they
    # match literally.
    grep_cmd = "ls /etc/pam.d/ | grep '^system-auth\\.D\\.'"
    grep_ssh_cmd = "ls /etc/pam.d/ | grep '^sshd\\.D\\.'"
    result, output = CBBCommonUtils.cbb_run_cmd(ip, grep_cmd, username=sys_user, passwd=sys_pwd)
    result1, output1 = CBBCommonUtils.cbb_run_cmd(ip, grep_ssh_cmd, username=sys_user, passwd=sys_pwd)
    if result or result1 or len(output) == 0 or len(output1) == 0:
        error_count = -1
        des_list.append("Backup file NOT FOUND.")
        logger.debug_error(cur_user, cur_module, cur_task + '_Rollback', '', "Backup file not found.")
        return des_list, error_count

    logger.debug_info(cur_user, cur_module, cur_task + '_Rollback', '', "Backup file FOUND.")
    # ls sorts names alphabetically. With a YYYY-MM-DD suffix that order is
    # chronological, so the last entry is the newest backup.
    target = output[-1].replace("\n", "")
    target_ssh = output1[-1].replace("\n", "")
    cmd = "/bin/cp -rf /etc/pam.d/" + target + " /etc/pam.d/system-auth"
    cmd_ssh = "/bin/cp -rf /etc/pam.d/" + target_ssh + " /etc/pam.d/sshd"

    # Restore both files.
    result, output = CBBCommonUtils.cbb_run_cmd(ip, cmd, username=sys_user, passwd=sys_pwd)
    result1, output1 = CBBCommonUtils.cbb_run_cmd(ip, cmd_ssh, username=sys_user, passwd=sys_pwd)
    if result or result1:
        error_count += 1
        des = "Rollback FAILED, please retry."
        logger.debug_error(cur_user, cur_module, cur_task + '_Rollback', '', des)
        des_list.append(des)
        return des_list, error_count
    logger.debug_info(cur_user, cur_module, cur_task + '_Rollback', '', "Rollback SUCCEED.")

    # Delete the backups that were just restored.
    del_cmd = "rm -rf /etc/pam.d/" + target
    del_ssh_cmd = "rm -rf /etc/pam.d/" + target_ssh
    result, output = CBBCommonUtils.cbb_run_cmd(ip, del_cmd, username=sys_user, passwd=sys_pwd)
    result1, output1 = CBBCommonUtils.cbb_run_cmd(ip, del_ssh_cmd, username=sys_user, passwd=sys_pwd)
    if result or result1:
        error_count += 1
        del_des = "Backup file delete FAILED."
        logger.debug_error(cur_user, cur_module, cur_task + '_Rollback', '', del_des)
        des_list.append(del_des)
        return des_list, error_count
    del_des = "Backup file deleted. Target is {0}, {1}.".format(target, target_ssh)
    logger.debug_info(cur_user, cur_module, cur_task + '_Rollback', '', del_des)

    des_list.append("Rollback Complete.")
    logger.debug_info(cur_user, cur_module, cur_task + '_Rollback', '', "Rollback Complete.")
    return des_list, error_count


def check(ip, *args, **kwargs):
    """Plugin entry point: run a scan, a scan with reinforcement, or a rollback.

    Keyword args:
        system_user, system_pwd: Credentials for the remote commands.
        command: 1 = scan only, 2 = scan and reinforce if anything is missing,
            3 = rollback.

    Returns:
        A dict with keys "code", "count" and "des". "code" is:
        - 0 when the step reported no errors,
        - 1 when it reported errors or raised,
        - 2 when it returned a negative count (rollback found no backup),
        - 3 for an unknown command.
    """
    sys_user = kwargs.get("system_user")
    sys_pwd = kwargs.get("system_pwd")
    comm = kwargs.get("command")
    try:
        des_list = []
        des_start = "centos_authentication_006 Start."
        logger.debug_info(cur_user, cur_module, cur_task + '_Check', '', des_start)
        if comm == 1:
            step_des, step_err = scan(ip, sys_user, sys_pwd, flag=0)
        elif comm == 2:
            step_des, step_err = scan(ip, sys_user, sys_pwd, flag=1)
        elif comm == 3:
            step_des, step_err = rollback(ip, sys_user, sys_pwd)
        else:
            return {"code": 3, "count": 0, "des": ['command must be 1/2/3']}
        des_list.extend(step_des)
        step_error = int(step_err)
        des_end = "centos_authentication_006 Complete."
        logger.debug_info(cur_user, cur_module, cur_task + '_Check', '', des_end)
        if step_error == 0:
            code = 0
        elif step_error <= -1:
            code = 2
        else:
            code = 1
        return {"code": code, "count": step_error, "des": des_list}
    except Exception as er:
        des = ["ERROR:", str(er)]
        # Log a plain string, like every other logger call in this module. The
        # returned "des" keeps its list form for callers.
        logger.debug_error(cur_user, cur_module, cur_task + '_Check', '', " ".join(des))
        return {"code": 1, "count": 0, "des": des}

# Example (call set_plugin_logger() first; command must be 1, 2 or 3):
# check(ip="<host>", system_user="root", system_pwd="<password>", command=1)
